package club.blueline.gateway.filter;

import club.blueline.gateway.config.AuthSkipProperty;
import club.blueline.gateway.entity.*;
import club.blueline.gateway.utils.JwtUtils;
import club.blueline.gateway.utils.RedisUtils;
import club.blueline.gateway.utils.ResultUtils;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.SneakyThrows;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import javax.annotation.Resource;
import java.util.*;

@Component
public class AuthFilter implements GlobalFilter, Ordered {

    @Resource
    private AuthSkipProperty authConfig;

    private Set<SkipUrl> skipAuthUrlSet;
    private ObjectMapper objectMapper = new ObjectMapper();

    @Resource
    private RedisUtils redisUtils;

    @SneakyThrows
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        String url = exchange.getRequest().getURI().getPath();
        String method = exchange.getRequest().getMethodValue().toUpperCase();
        SkipUrl skipUrl1 = new SkipUrl(method, url);
        SkipUrl skipUrl2 = new SkipUrl("*", url);
        if(skipAuthUrlSet == null) {
            skipAuthUrlSet = new HashSet<>();
            List<Map<String, String>> authMaps = authConfig.getSkip();
            for(Map<String, String> authMap : authMaps) {
                skipAuthUrlSet.add(new SkipUrl(authMap.get("method").toUpperCase(), authMap.get("url")));
            }
        }
        if(skipAuthUrlSet.contains(skipUrl1) || skipAuthUrlSet.contains(skipUrl2)) {
            return chain.filter(exchange);
        }
        String tokenHeader = exchange.getRequest().getHeaders().getFirst(JwtUtils.TOKEN_HEADER);
        if (tokenHeader == null || !tokenHeader.startsWith(JwtUtils.TOKEN_PREFIX)) {
            // 无Token
            ServerHttpResponse response = exchange.getResponse();
            DataBuffer buffer = genUnauthorizedMessage("未认证的操作！", response);
            return response.writeWith(Flux.just(buffer));
        }
        String token = tokenHeader.replace(JwtUtils.TOKEN_PREFIX, "");
        boolean expiration = JwtUtils.isExpiration(token);
        String username = JwtUtils.getUsername(token);
        Integer userId = JwtUtils.getUserId(token);
        String role = JwtUtils.getUserRole(token);
        Set<ApiAuthority> authorities = null;
        if(expiration) {
            // 过期后检查是否在Redis中
            if(!redisUtils.hasKey("auth:"+username)) {
                ServerHttpResponse response = exchange.getResponse();
                DataBuffer buffer = genUnauthorizedMessage("Token过期，请重新登陆", response);
                return response.writeWith(Flux.just(buffer));
            }
            RedisBean bean = objectMapper.readValue(redisUtils.get("auth:"+username), RedisBean.class);
            if(!bean.getToken().equals(token)) {
                ServerHttpResponse response = exchange.getResponse();
                DataBuffer buffer = genUnauthorizedMessage("Token非法，请重新登陆", response);
                return response.writeWith(Flux.just(buffer));
            } else {
                // Token合法过期，自动续期
                token = JwtUtils.createToken(username, role);
                RedisBean redisBean = new RedisBean(bean.getId(), token, role, bean.getAuthorities());
                redisUtils.set("auth:"+username, objectMapper.writeValueAsString(redisBean), JwtUtils.EXPIRATION * 2);
                ServerHttpResponse response = exchange.getResponse();
                response.getHeaders().add("Content-Type", "application/json;charset=UTF-8");
                response.getHeaders().add("token", JwtUtils.TOKEN_PREFIX + token);
                authorities = bean.getAuthorities();
            }
        }

        // 验证URL和方式
        ApiAuthority tmp1 = new ApiAuthority(exchange.getRequest().getMethodValue(), exchange.getRequest().getURI().getPath());
        ApiAuthority tmp2 = new ApiAuthority("*", exchange.getRequest().getURI().getPath());
        if(authorities.contains(tmp1) || authorities.contains(tmp2)) {
            ServerHttpRequest req = exchange.getRequest().mutate().header("Authorization-id", String.valueOf(userId)).header("Authorization-role", role).build();
            return chain.filter(exchange.mutate().request(req).build());
        } else {
            ServerHttpResponse response = exchange.getResponse();
            DataBuffer buffer = genUnauthorizedMessage("未授权的操作！", response);
            return response.writeWith(Flux.just(buffer));
        }
    }

    private DataBuffer genUnauthorizedMessage(String message, ServerHttpResponse response) {
        response.setStatusCode(HttpStatus.OK);
        response.getHeaders().add("Content-Type", "application/json;charset=UTF-8");
        Result result = ResultUtils.genFailResult(message);
        result.setCode(ResultCode.UNAUTHORIZED.code());
        DataBuffer buffer = response.bufferFactory().wrap(result.toString().getBytes());
        return buffer;
    }

    @Override
    public int getOrder() {
        return -100;
    }
}
